package com.lauor.smpedr.core.helper;

import com.lauor.smpedr.core.anno.FieldName;
import com.lauor.smpedr.core.anno.Table;
import com.lauor.smpedr.param.OptEnum;
import com.lauor.smpedr.param.SqlArgMap;
import com.lauor.smpedr.utils.Str;

import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.*;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * Persistence helper for edr.
 *
 * <p>Resolves table/column names from entity metadata and converts SQL written with named
 * {@code #{name}} placeholders into positional {@code ?} SQL plus ordered parameter values.
 */
public class EdrHelper {
    /**
     * Matches a named SQL placeholder such as {@code #{userId}}; group 1 is the raw
     * (untrimmed) name. Compiled once because every SQL build uses it.
     */
    private static final Pattern PLACEHOLDER_PATTERN = Pattern.compile("#\\{(.*?)}");

    /** Separator used to locate the WHERE clause in generated SQL (case-sensitive). */
    private static final String WHERE_SEPARATOR = " where ";

    /**
     * Returns the database table name mapped to the given class.
     *
     * @param cls entity class
     * @return the {@code @Table} name when the class carries a non-empty {@code @Table}
     *         annotation, otherwise the class's simple name
     */
    public static String getTableName(Class<?> cls){
        Table tableAno = cls.getAnnotation(Table.class);
        if (tableAno != null && !Str.isEmpty(tableAno.name())){
            return tableAno.name();
        }
        return cls.getSimpleName();
    }

    /**
     * Returns the names of the fields declared directly on {@code clazz} that carry the given
     * annotation.
     *
     * @param clazz     class whose declared fields are inspected
     * @param annoClazz annotation type to look for
     * @return names of the matching fields (possibly empty, never null)
     */
    public static Set<String> findFieldsByAnt(Class clazz, Class annoClazz){
        return findFields(clazz, field -> field.isAnnotationPresent(annoClazz));
    }

    /**
     * Returns the names of the fields declared directly on {@code clazz} (inherited fields are
     * not included) that the given visitor accepts.
     *
     * @param clazz        class whose declared fields are inspected
     * @param fieldVisitor custom selection logic; receives each declared field and returns
     *                     whether it should be included ({@code null} is treated as false)
     * @return names of the accepted fields, in the order {@code getDeclaredFields()} reports
     *         them; never null
     */
    public static Set<String> findFields(Class<?> clazz, Function<Field, Boolean> fieldVisitor){
        Set<String> fieldNameSet = new LinkedHashSet<>();
        for (Field field : clazz.getDeclaredFields()) {
            // Boolean.TRUE.equals avoids a NullPointerException from auto-unboxing a null verdict.
            if (Boolean.TRUE.equals(fieldVisitor.apply(field))){
                fieldNameSet.add(field.getName());
            }
        }
        return fieldNameSet;
    }

    /**
     * Returns the database column name mapped to a field.
     *
     * <p>Side effect: marks the given field accessible.
     *
     * @param field the field; may be null
     * @return the non-empty {@code @FieldName} value if present, otherwise the field's own name;
     *         {@code ""} when {@code field} is null
     */
    public static String getDbFieldName(Field field){
        if (field == null) {
            return "";
        }

        field.setAccessible(true);
        FieldName anno = field.getDeclaredAnnotation(FieldName.class);
        return anno == null || Str.isEmpty(anno.value()) ? field.getName() : anno.value();
    }

    // ---------------------------------------------------------------- SQL helpers

    /**
     * Builds the named placeholder for a property (injection-safe mode: the value is later
     * bound as a parameter rather than concatenated into the SQL).
     *
     * @param fieldName property name
     * @return {@code "#{" + fieldName + "}"}
     */
    public static String generateSqlPlaceHolder(String fieldName){
        return "#{" + fieldName + "}";
    }

    /**
     * Replaces every named {@code #{...}} placeholder with a positional {@code ?}
     * (injection-safe mode).
     *
     * @param sql SQL containing named placeholders; returned unchanged when {@code Str.isNull}
     *            reports it as null
     * @return SQL with positional placeholders
     */
    public static String replaceSqlPlaceHolder(String sql){
        return Str.isNull(sql) ? sql : PLACEHOLDER_PATTERN.matcher(sql).replaceAll("?");
    }

    /**
     * Returns the parameter values in the order their placeholders appear in {@code sql}.
     *
     * <p>When an argument is a collection and its operator is a collection operator, all of
     * its elements are emitted at once. The following placeholders already covered by those
     * elements are then skipped. This assumes the SQL repeats the placeholder once per element,
     * consecutively. A placeholder with no matching argument yields {@code null}.
     *
     * @param sql    SQL containing named {@code #{...}} placeholders
     * @param sqlArg named arguments
     * @return ordered parameter values; an empty immutable list when {@code sql} or
     *         {@code sqlArg} is null/empty
     */
    public static List getParamsSortedBySql(String sql, SqlArgMap sqlArg){
        if (Str.isNull(sql) || sqlArg == null || sqlArg.isEmpty()) {
            return Collections.emptyList();
        }

        Matcher matcher = PLACEHOLDER_PATTERN.matcher(sql);
        // Parameter values, in placeholder order.
        List<Object> paramList = new ArrayList<>(sqlArg.size());
        int seenPlaceholders = 0;
        while (matcher.find()){
            // paramList.size() is the number of placeholders already covered by emitted values
            // (one per scalar, one per element of an expanded collection). Using it directly
            // keeps the skip logic in step with every branch, including the missing-argument one.
            if (seenPlaceholders++ < paramList.size()) {
                continue;
            }

            String properName = matcher.group(1).trim();
            SqlArgMap.ArgNode paramNode = sqlArg.getValue(properName);
            if (paramNode == null){
                paramList.add(null);
                continue;
            }
            // Collection argument: expand into one value per element.
            if (OptEnum.EQ.isParamCollection(paramNode.getOpt()) && (paramNode.getValue() instanceof Collection)){
                paramList.addAll((Collection<?>) paramNode.getValue());
            } else {
                paramList.add(paramNode.getValue());
            }
        }
        return paramList;
    }

    /**
     * Builds one parameter row per data object for an insert statement.
     *
     * <p>The distinct property names are read from the placeholders in order. Reading stops at
     * the first repeated name, presumably because a multi-row insert repeats the same
     * placeholder group for every row. Each row holds the getter values of those properties.
     * Getters are resolved on the class of the first element.
     *
     * @param sql      insert SQL with named {@code #{...}} placeholders
     * @param dataList objects to insert
     * @param <E>      element type
     * @return one value list per element of {@code dataList}; an empty immutable list when the
     *         input is null/empty or the SQL contains no placeholders
     */
    public static <E> List<List> getInsertParamsSortedBySql(String sql, List<E> dataList){
        if (Str.isNull(sql) || dataList == null || dataList.isEmpty()) {
            return Collections.emptyList();
        }

        // Property names used by one row, in placeholder order.
        Matcher matcher = PLACEHOLDER_PATTERN.matcher(sql);
        Set<String> fieldsInSql = new LinkedHashSet<>();
        while (matcher.find()){
            if (!fieldsInSql.add(matcher.group(1).trim())) {
                break;
            }
        }
        if (fieldsInSql.isEmpty()) {
            return Collections.emptyList();
        }

        Class<?> clazz = dataList.get(0).getClass();
        List<Method> getterMethods = new ArrayList<>(fieldsInSql.size());
        for (String fieldName : fieldsInSql) {
            Method getter = OrmHelper.getGetterMethodByField(fieldName, clazz);
            if (getter != null) {
                // Loop-invariant: make accessible once rather than once per row.
                getter.setAccessible(true);
            }
            getterMethods.add(getter);
        }

        // Assemble the parameter values row by row.
        List<List> rsList = new ArrayList<>(dataList.size());
        for (E data : dataList) {
            List<Object> row = new ArrayList<>(getterMethods.size());
            for (Method getter : getterMethods) {
                row.add(readProperty(getter, data));
            }
            rsList.add(row);
        }
        return rsList;
    }

    /**
     * Invokes {@code getter} on {@code target}.
     *
     * @return the getter's result, or {@code null} when the getter is missing, inaccessible,
     *         or throws
     */
    private static Object readProperty(Method getter, Object target) {
        if (getter == null) {
            return null;
        }
        try {
            return getter.invoke(target);
        } catch (IllegalAccessException | InvocationTargetException ignored) {
            // Best effort, kept for backward compatibility: a failing getter binds NULL for
            // that column. NOTE(review): this can silently persist NULL; consider failing instead.
            return null;
        }
    }

    /**
     * Returns the part of {@code sql} after the first {@code " where "} (case-sensitive).
     *
     * <p>Any later {@code " where "} (for example inside a sub-query) is kept verbatim.
     *
     * @param sql SQL statement
     * @return the text after the first {@code " where "}; {@code ""} when there is none;
     *         {@code sql} itself when {@code Str.isEmpty(sql)}
     */
    public static String getSqlWherePart(String sql){
        if (Str.isEmpty(sql)) {
            return sql;
        }

        int whereIdx = sql.indexOf(WHERE_SEPARATOR);
        if (whereIdx < 0) {
            return "";
        }
        // Cut once at the first separator; splitting on every occurrence and re-joining
        // without the separator dropped the "where" keywords of nested queries.
        return sql.substring(whereIdx + WHERE_SEPARATOR.length());
    }

    // ---------------------------------------------------------------- ResultSet helpers

    /**
     * Returns the mapping from database column name to entity field name for the fields
     * declared directly on {@code cls}.
     *
     * <p>Column names are resolved with {@link #getDbFieldName(Field)}, so an empty
     * {@code @FieldName} value falls back to the field name.
     *
     * @param cls entity class
     * @return column name -> field name
     */
    public static Map<String, String> getFieldDbRelateEntity(Class<?> cls){
        Map<String, String> result = new HashMap<>();
        for (Field field : cls.getDeclaredFields()) {
            result.put(getDbFieldName(field), field.getName());
        }
        return result;
    }
}
